package com.kool.kmqtt.server.repository.local;

import com.alibaba.fastjson.JSON;
import com.kool.kmqtt.server.repository.subscription.SubscribeInfo;
import com.kool.kmqtt.server.repository.subscription.Subscriber;
import com.kool.kmqtt.server.repository.subscription.Subscription;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import static com.kool.kmqtt.util.TopicUtil.getNextItmes;
import static com.kool.kmqtt.util.TopicUtil.split;

/**
 * Subscription tree (a forest) associating topic filters with subscribed clients.
 * <p>
 * Every path from a root down to a node spells one topic filter, one level per node
 * (levels as produced by {@code TopicUtil.split}). A node's {@code subscribers} map holds the
 * clients subscribed to exactly the filter ending at that node, with their granted QoS.
 * <p>
 * Thread-safety: all maps are {@link ConcurrentHashMap}s and new nodes are created with
 * {@code computeIfAbsent}, so concurrent subscribes cannot overwrite each other's nodes.
 * Traversals ({@link #match}, {@link #getAllCopy}) iterate weakly-consistent views and may or
 * may not observe concurrent modifications. Nodes are never pruned, even after their last
 * subscriber is removed.
 */
public class SubscribeTrees {
    private static final SubscribeTrees INSTANCE = new SubscribeTrees();

    /** MQTT multi-level wildcard: matches the parent level and any number of child levels. */
    private static final String MULTI_LEVEL_WILDCARD = "#";
    /** MQTT single-level wildcard: matches exactly one topic level. */
    private static final String SINGLE_LEVEL_WILDCARD = "+";
    /** Separator used when rebuilding a topic filter string from its levels. */
    private static final String LEVEL_SEPARATOR = "/";

    /** Roots of the forest, keyed by the first topic-filter level. */
    private final ConcurrentHashMap<String, Node> trees = new ConcurrentHashMap<>();

    private SubscribeTrees() {
    }

    /**
     * Returns the process-wide singleton instance.
     *
     * @return the shared subscription tree
     */
    public static SubscribeTrees getInstance() {
        return INSTANCE;
    }

    /** One topic-filter level in the tree. */
    private static final class Node {
        /** Child nodes, keyed by the next topic-filter level name. */
        final ConcurrentHashMap<String, Node> children = new ConcurrentHashMap<>();
        /** Clients subscribed to the filter ending at this node, keyed by clientId. */
        final ConcurrentHashMap<String, Subscriber> subscribers = new ConcurrentHashMap<>();
    }

    /**
     * Returns a snapshot of every subscription in the tree.
     *
     * @return a new list with one {@link Subscription} per (topic filter, client) pair; each
     *         entry carries the full topic filter rebuilt from the tree path
     */
    public List<Subscription> getAllCopy() {
        List<Subscription> subscriptions = new ArrayList<>();
        collectSubscriptions(null, trees, subscriptions);
        return subscriptions;
    }

    /**
     * Depth-first walk that appends one {@link Subscription} per subscriber found below
     * {@code level}.
     *
     * @param prefix topic filter of the parent node, or {@code null} at the root level
     * @param level  the children of the parent node (or the forest roots)
     * @param out    list that receives the subscriptions
     */
    private void collectSubscriptions(String prefix, Map<String, Node> level, List<Subscription> out) {
        for (Map.Entry<String, Node> entry : level.entrySet()) {
            // null (not "") marks the root, so an empty-string level is still joined with a separator
            String topicFilter = prefix == null
                    ? entry.getKey()
                    : prefix + LEVEL_SEPARATOR + entry.getKey();
            Node node = entry.getValue();
            for (Subscriber subscriber : node.subscribers.values()) {
                Subscription subscription = new Subscription();
                subscription.setTopicFilter(topicFilter);
                subscription.setClientId(subscriber.getClientId());
                subscription.setQos(subscriber.getQos());
                out.add(subscription);
            }
            // pass the full path, not just this level's name, so deep filters are rebuilt correctly
            collectSubscriptions(topicFilter, node.children, out);
        }
    }

    /**
     * Adds (or replaces) a client's subscription.
     *
     * @param clientId      the subscribing client's id
     * @param subscribeInfo topic filter and granted QoS
     */
    public void add(String clientId, SubscribeInfo subscribeInfo) {
        add(subscribeInfo.getTopicFilter(), clientId, subscribeInfo.getQos());
    }

    /**
     * Adds (or replaces) a client's subscription to a topic filter. Missing levels along the
     * path are created; an existing subscription by the same client to the same filter is
     * overwritten with the new QoS.
     *
     * @param topicFilter the topic filter
     * @param clientId    the subscribing client's id
     * @param qos         the granted QoS
     * @throws IllegalArgumentException if the topic filter splits into no levels
     */
    public void add(String topicFilter, String clientId, int qos) {
        String[] topicFilterItems = split(topicFilter);
        if (topicFilterItems == null || topicFilterItems.length == 0) {
            throw new IllegalArgumentException("Invalid topic filter: " + topicFilter);
        }

        Node node = null;
        ConcurrentHashMap<String, Node> level = trees;
        for (String item : topicFilterItems) {
            // computeIfAbsent is atomic on ConcurrentHashMap: two clients subscribing to filters
            // with a shared new prefix cannot overwrite each other's freshly created node
            node = level.computeIfAbsent(item, key -> new Node());
            level = node.children;
        }

        Subscriber subscriber = new Subscriber();
        subscriber.setClientId(clientId);
        subscriber.setQos(qos);
        node.subscribers.put(clientId, subscriber);
    }

    /**
     * Removes a client's subscription to one topic filter. Does nothing if the filter or the
     * subscription does not exist.
     *
     * @param topicFilter the topic filter
     * @param clientId    the client whose subscription is removed
     */
    public void delete(String topicFilter, String clientId) {
        Node node = findNode(split(topicFilter));
        if (node != null) {
            node.subscribers.remove(clientId);
        }
    }

    /**
     * Follows the exact path of {@code topicFilterItems} (wildcards are treated literally).
     *
     * @return the node at the end of the path, or {@code null} if any level is missing
     */
    private Node findNode(String[] topicFilterItems) {
        if (topicFilterItems == null || topicFilterItems.length == 0) {
            return null;
        }
        Node node = null;
        Map<String, Node> level = trees;
        for (String item : topicFilterItems) {
            node = level.get(item);
            if (node == null) {
                return null;
            }
            level = node.children;
        }
        return node;
    }

    /**
     * Removes every subscription held by a client, across all topic filters.
     *
     * @param clientId the client whose subscriptions are removed
     */
    public void deleteSubscriber(String clientId) {
        deleteSubscriberScan(trees, clientId);
    }

    /** Recursively removes {@code clientId} from every node below {@code children}. */
    private void deleteSubscriberScan(Map<String, Node> children, String clientId) {
        for (Node node : children.values()) {
            node.subscribers.remove(clientId);
            deleteSubscriberScan(node.children, clientId);
        }
    }

    /**
     * Finds the clients whose topic filters match a published topic.
     * <p>
     * A client whose wildcard subscriptions overlap can match several filters. In that case the
     * client appears once, with the highest QoS among its matching subscriptions, as required
     * for delivery.
     *
     * @param topic the published topic name
     * @return matching subscribers, one per client
     */
    public List<Subscriber> match(String topic) {
        String[] topicItems = split(topic);
        Map<String, Subscriber> subscribers = new HashMap<>();
        matchScan(trees, topicItems, subscribers);
        return new ArrayList<>(subscribers.values());
    }

    /**
     * Matches the remaining topic levels against one tree level, collecting matching subscribers.
     *
     * @param children    the tree level to match against
     * @param topicItems  remaining topic levels; {@code topicItems[0]} is matched at this level
     * @param subscribers result map keyed by clientId (highest QoS wins)
     */
    private void matchScan(Map<String, Node> children, String[] topicItems, Map<String, Subscriber> subscribers) {
        if (topicItems == null || topicItems.length == 0 || children.isEmpty()) {
            return;
        }
        boolean lastLevel = topicItems.length == 1;
        String topicLevel = topicItems[0];

        for (Map.Entry<String, Node> entry : children.entrySet()) {
            String topicFilterItem = entry.getKey();
            Node node = entry.getValue();
            boolean levelMatches = topicFilterItem.equals(SINGLE_LEVEL_WILDCARD)
                    || topicFilterItem.equals(topicLevel);

            if (topicFilterItem.equals(MULTI_LEVEL_WILDCARD)) {
                // "#" matches this level and everything below it
                mergeHighestQos(subscribers, node);
            } else if (levelMatches && lastLevel) {
                mergeHighestQos(subscribers, node);
                // "x/#" also matches the parent topic "x" itself
                Node multiLevel = node.children.get(MULTI_LEVEL_WILDCARD);
                if (multiLevel != null) {
                    mergeHighestQos(subscribers, multiLevel);
                }
            } else if (levelMatches) {
                // this level matches, so continue with the next topic level in the subtree
                matchScan(node.children, getNextItmes(topicItems), subscribers);
            }
        }
    }

    /**
     * Adds the node's subscribers to {@code result}. If a client is already present, the entry
     * with the higher QoS is kept.
     */
    private static void mergeHighestQos(Map<String, Subscriber> result, Node node) {
        for (Subscriber subscriber : node.subscribers.values()) {
            result.merge(subscriber.getClientId(), subscriber,
                    (existing, candidate) -> existing.getQos() < candidate.getQos() ? candidate : existing);
        }
    }

    public static void main(String[] args) {
        SubscribeTrees.getInstance().add("a/#", "001", 0);
        SubscribeTrees.getInstance().add("b/+", "001", 0);
        SubscribeTrees.getInstance().add("c/1/#", "001", 0);
        SubscribeTrees.getInstance().add("c/2/+", "001", 0);
        SubscribeTrees.getInstance().add("c/2/+/1", "001", 0);
        SubscribeTrees.getInstance().add("d/1/2//3", "001", 0);


        SubscribeTrees.getInstance().add("d/#", "002", 0);

        List<Subscriber> subscribers = SubscribeTrees.getInstance().match("a/1");
        System.out.println("a/1:" + JSON.toJSONString(subscribers));
        subscribers = SubscribeTrees.getInstance().match("a");
        System.out.println("a:" + JSON.toJSONString(subscribers));
        subscribers = SubscribeTrees.getInstance().match("b/1");
        System.out.println("b/1:" + JSON.toJSONString(subscribers));
        subscribers = SubscribeTrees.getInstance().match("c/1");
        System.out.println("c/1:" + JSON.toJSONString(subscribers));
        subscribers = SubscribeTrees.getInstance().match("c/2");
        System.out.println("c/2:" + JSON.toJSONString(subscribers));
        subscribers = SubscribeTrees.getInstance().match("c/2/3");
        System.out.println("c/2/3:" + JSON.toJSONString(subscribers));
        subscribers = SubscribeTrees.getInstance().match("c/2/3/4");
        System.out.println("c/2/3/4:" + JSON.toJSONString(subscribers));
        subscribers = SubscribeTrees.getInstance().match("c/2/3/1");
        System.out.println("c/2/3/1:" + JSON.toJSONString(subscribers));
        subscribers = SubscribeTrees.getInstance().match("d/1/2//3");
        System.out.println("d/1/2/3:" + JSON.toJSONString(subscribers));
    }
}
